Coin toss results

In [1]:
import pickle
import jax

import matplotlib.pyplot as plt
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from scipy.stats import gaussian_kde
import plotly.express as px
import pandas as pd
import pickle
# Short alias for the distributions module of the JAX-substrate TFP imported above.
tfd = tfp.distributions
import plotly
# Put plotly into offline notebook mode.
# NOTE(review): presumably needed so `fig.show()` in the plotting cell renders inline — confirm.
plotly.offline.init_notebook_mode()
In [2]:
def get_data(file_name):
    """Read the pickled coin-toss dataset stored at ``file_name``.

    The pickle must contain a mapping with a ``'samples'`` entry and a
    ``'prior'`` entry that itself provides ``'alpha'`` and ``'beta'``.

    Args:
        file_name: Path of the pickle file; it is opened in binary read mode.

    Returns:
        The tuple ``(samples, alpha_prior, beta_prior)`` taken from the mapping.

    Raises:
        KeyError: If one of the expected keys is missing from the mapping.
    """
    # pickle.load can execute arbitrary code: only point this at trusted files.
    with open(file_name, 'rb') as handle:
        payload = pickle.load(handle)
    observations = payload['samples']
    prior = payload['prior']
    return observations, prior['alpha'], prior['beta']
In [3]:
# Load the observed samples together with the prior's alpha/beta, then show
# how often each value occurs in the raw data.
samples, alpha_prior, beta_prior = get_data("../../data/coin_toss")
hist_fig, hist_ax = plt.subplots()
hist_ax.hist(samples)
hist_ax.set_ylabel("frequency")
hist_ax.set_title("Given Data")
plt.show()
In [4]:
# Reference posterior: the prior's alpha and beta are incremented by the number
# of samples equal to 1 and to 0 respectively, and the resulting Beta density
# is evaluated on a 100-point grid spanning [0.01, 0.99].
samples, alpha_prior, beta_prior = get_data("../../data/coin_toss")
x = jnp.linspace(0.01, 0.99, num=100)
one = jnp.sum(samples == 1).astype(jnp.float32)
zero = jnp.sum(samples == 0).astype(jnp.float32)
true_post_dist = tfd.Beta(
    concentration1=alpha_prior + one,
    concentration0=beta_prior + zero,
)
true_post_pdf = true_post_dist.prob(x)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
In [5]:
# Restore the pickled Ajax VI result and convert its log-density on the grid
# `x` into a density.
with open('results_data/coin_toss_VI_Ajax_result', 'rb') as vi_file:
    variational = pickle.load(vi_file)
ajax_vi_log_pdf = variational.log_prob({"theta": x})
ajax_vi_pdf = jnp.exp(ajax_vi_log_pdf)
In [6]:
# Restore the pickled BlackJAX samples and estimate their density on the grid
# `x` with a Gaussian KDE (bandwidth factor fixed at 0.3).
with open('results_data/MCMC_BlackJAX', 'rb') as mcmc_file:
    black_samples = pickle.load(mcmc_file)
kde_black = gaussian_kde(black_samples, bw_method=0.3)
pdf_black = kde_black.evaluate(x)
In [7]:
# Restore the pickled Laplace-approximation result and evaluate its density on
# the grid `x`.
# NOTE(review): the name suggests a Normal distribution object — confirm.
with open('results_data/coin_toss_laplace_result', 'rb') as laplace_file:
    laplace_normal = pickle.load(laplace_file)
laplace_pdf = laplace_normal.prob(x)
In [8]:
def get_likelihood(params, aux=None):
    """Build the Bernoulli likelihood from the parameter mapping.

    NOTE(review): nothing in the visible cells calls this function directly;
    presumably it has to exist under this name so the pickled Laplax model
    can be restored — confirm before renaming or removing it.

    Args:
        params: Mapping that provides the Bernoulli probability under the
            key ``'p_of_h'``.
        aux: Accepted but not used by this function.

    Returns:
        A ``tfd.Bernoulli`` distribution with ``probs=params['p_of_h']``.
    """
    success_probability = params['p_of_h']
    return tfd.Bernoulli(probs=success_probability)
# Restore the pickled Laplax bundle (model, params, data), apply the model to
# obtain the posterior object, and convert its log-density on the grid `x`
# into a density.
laplax_dict = pd.read_pickle('results_data/laplax_coin_toss')
laplax_posterior = laplax_dict['model'].apply(
    laplax_dict['params'],
    laplax_dict['data'],
)
laplax_log_pdf = laplax_posterior.log_prob({'p_of_h': x})
laplax_pdf = jnp.exp(laplax_log_pdf)
tfp.distributions.Normal("Normal", batch_shape=[1], event_shape=[], dtype=float32) (100,)
In [9]:
# Gather every posterior estimate under its legend label. Labels, densities and
# the repeated grid are all derived from this one mapping, so adding or
# removing a curve cannot leave them out of sync.
curves = {
    "True Posterior": true_post_pdf,
    "Ajax VI estimate": ajax_vi_pdf,
    "Blackjax rmh estimate": pdf_black,
    "Laplace approximation": laplace_pdf,
    "Laplax": laplax_pdf,
}
n_points = x.shape[0]

# Long ("tidy") layout: one row per (curve, grid point) pair, curve after curve.
all_pdfs = jnp.array(list(curves.values())).reshape((-1))
all_labels = [label for label in curves for _ in range(n_points)]
x_repeated = jnp.tile(x, len(curves))
to_df = {
    "theta": x_repeated,
    "PDF": all_pdfs,
    "label": all_labels,
}
df = pd.DataFrame(to_df)

# Plot from the DataFrame that was just built, naming the axes explicitly.
fig = px.line(df, x="theta", y="PDF", color="label", title="Coin toss posterior")
fig.show()
In [10]:
# Repeats the plotly import and offline-mode initialisation from the first cell.
# NOTE(review): presumably kept so the figure above still renders when the
# notebook is exported or reopened — confirm before removing.
import plotly
plotly.offline.init_notebook_mode()